import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import visual_behavior_glm
import visual_behavior_glm.GLM_params as glm_params
import visual_behavior_glm.GLM_analysis_tools as gat
import visual_behavior.data_access.loading as loading
%matplotlib inline
%load_ext autoreload
%autoreload 2
# Load the filtered ophys experiment table through the visual_behavior loading module and preview it.
experiments_table = loading.get_filtered_ophys_experiment_table()
experiments_table.head()
# Disabled alternative: read the experiment table from a local CSV and index it by ophys_experiment_id.
# filepath = r"C:\Users\marinag\Dropbox\GLM\filtered_ophys_experiment_table.csv"
# experiments_table = pd.read_csv(filepath)
# experiments_table = experiments_table.set_index('ophys_experiment_id')
# experiments_table.head()
# Model output value and GLM version used by every cell below.
model_output_type = 'adj_fraction_change_from_full'
glm_version = '9a_L2_optimize_by_session'
# Build the pivoted results summary ("rspm") for this GLM version and output value.
# NOTE(review): the meaning of cutoff=0.01 is defined inside build_pivoted_results_summary - confirm there.
rspm = gat.build_pivoted_results_summary(value_to_use=model_output_type, results_summary=None,
glm_version=glm_version, cutoff=0.01)
# Disabled: cache rspm as HDF5 on the network share, or reload it from a local copy.
# # save
# save_dir = r'\\allen\programs\braintv\workgroups\nc-ophys\visual_behavior\GLM'
# rspm.to_hdf(os.path.join(save_dir, glm_version+'_'+model_output_type+'.h5'), key='df')
# # load
# load_dir = r'C:\Users\marinag\Dropbox\GLM'
# filepath = os.path.join(load_dir, glm_version+'_'+model_output_type+'.h5')
# rspm = pd.read_hdf(filepath, key='df')
rspm.head()
# Derive candidate feature columns from rspm: everything that is not metadata,
# not a per-image / full-model column, and does not contain 'single'.
metadata_columns = (
    'identifier', 'cre_line', 'session_type', 'equipment_name',
    'session_id', 'imaging_depth', 'project_code', 'session_number', 'exposure_number',
)
image_and_full_columns = (
    'image0', 'image1', 'image2', 'image3',
    'image4', 'image5', 'image6', 'image7',
    'Full',
)
features_to_plot = [
    col for col in rspm.columns
    if col not in metadata_columns
    and col not in image_and_full_columns
    and 'single' not in col
]
features_to_plot

# The derived list above is replaced by this explicit, hand-ordered list,
# which is what the rest of the script actually uses.
features_to_plot = [
    'visual', 'all-images', 'omissions',
    'behavioral', 'licking', 'licks', 'pupil', 'running',
    'face_motion_energy',
    'face_motion_PC_0', 'face_motion_PC_1', 'face_motion_PC_2',
    'face_motion_PC_3', 'face_motion_PC_4',
    'cognitive', 'image_expectation',
    'hits', 'misses', 'false_alarms', 'correct_rejects',
    'intercept',
    'beh_model', 'model_bias', 'model_omissions1', 'model_task0', 'model_timing1D',
    'task', 'time',
]

# Smaller subset used for the "level up" heatmap further down.
level_up_features = ['visual', 'all-images', 'omissions', 'behavioral', 'task', 'beh_model']

len(features_to_plot)
def plot_feature_matrix(pivoted_results_summary, model_output_type, glm_version, features_to_plot=None, sort_by=None, ax=None, save_figure=False):
    """
    Plot a heatmap of GLM feature values, one row per entry of pivoted_results_summary.

    NOTE: a later definition with the same name in this file replaces this one
    once both have been executed.

    :param pivoted_results_summary: dataframe with one column per GLM feature
    :param model_output_type: string used as the colorbar label
    :param glm_version: string of GLM version, used in the save directory and file name
    :param features_to_plot: list of columns to include. If None, the keys of
        glm_params.define_kernels() minus 'each-image' plus 'all-images', sorted
    :param sort_by: column to sort rows by before plotting; None keeps the input order
    :param ax: axis to plot on; if None, a 10x10 inch figure is created
    :param save_figure: if True, the figure is saved under the GLM directory for glm_version
    :return: figure axis
    """
    if ax is None:
        figsize = (10, 10)
        fig, ax = plt.subplots(figsize=figsize)
    else:
        # A caller-supplied axis used to leave `fig` and `figsize` undefined, so
        # save_figure=True raised NameError. Recover both from the axis instead.
        fig = ax.get_figure()
        figsize = tuple(fig.get_size_inches())

    if features_to_plot is None:
        params = list(glm_params.define_kernels().keys())
        params.remove('each-image')
        features_to_plot = list(np.sort(np.hstack((params, 'all-images'))))
        features = 'default_features'
    else:
        features = 'custom_features'

    if sort_by is not None:
        feature_matrix = pivoted_results_summary.sort_values(sort_by).reset_index()[features_to_plot]
        sort_label = sort_by
    else:
        feature_matrix = pivoted_results_summary[features_to_plot]
        sort_label = ''  # keeps the file name well-formed when no sorting was requested

    ax = sns.heatmap(feature_matrix, vmin=-0.5, vmax=0, center=0, cmap='RdBu_r', ax=ax, cbar_kws={'label':model_output_type})
    ax.set_ylabel('cells')
    ax.set_title('GLM feature matrix')

    if save_figure:
        import visual_behavior.visualization.utils as ut
        save_dir = os.path.join(loading.get_ophys_glm_dir(), 'v_'+glm_version, 'figures')
        ut.save_figure(fig, figsize, save_dir, 'feature_matrix_heatmaps', 'glm_feature_matrix_'+glm_version+'_sort_by_'+sort_label+'_'+features)
    return ax
# Preview the first rows of the pivoted results summary.
rspm.head()
def plot_feature_matrix(pivoted_results_summary, value_to_use, glm_version, features_to_plot=None, sort_by=None, ax=None, save_figure=False):
    """
    Plots a heatmap of GLM features from the pivoted_results_summary for a given glm_version

    :param pivoted_results_summary: output of GLM_analysis_tools.build_pivoted_results_summary(), with value_to_use such as 'adj_fraction_change_from_full'
    :param value_to_use: model output type used to create pivoted_results_summary, such as 'adj_fraction_change_from_full'; used as the colorbar label
    :param glm_version: string of GLM version, ex: '7_L2_optimize_by_session'
    :param features_to_plot: list of GLM features to include in the plot. If None provided, will select defaults from GLM_params.define_kernels()
        (all kernel names except 'each-image', plus 'all-images', sorted)
    :param sort_by: GLM feature to sort by, such as 'omissions', or 'Full'. If None, rows keep their input order
    :param ax: axis to plot, if None provided, a 10x10 inch figure will be created
    :param save_figure: if True, figure will be saved to a default location with figure title reflecting customized inputs to function
    :return: figure axis
    """
    if ax is None:
        figsize = (10, 10)
        fig, ax = plt.subplots(figsize=figsize)
    else:
        # A caller-supplied axis used to leave `fig` and `figsize` undefined, so
        # save_figure=True raised NameError. Recover both from the axis instead.
        fig = ax.get_figure()
        figsize = tuple(fig.get_size_inches())

    if features_to_plot is None:
        import visual_behavior_glm.GLM_params as glm_params
        params = list(glm_params.define_kernels().keys())
        params.remove('each-image')
        features_to_plot = list(np.sort(np.hstack((params, 'all-images'))))
        features = 'default_features'
    else:
        features = 'custom_features'

    if sort_by is not None:
        feature_matrix = pivoted_results_summary.sort_values(sort_by).reset_index()[features_to_plot]
        sort_label = sort_by
    else:
        feature_matrix = pivoted_results_summary[features_to_plot]
        sort_label = ''  # keeps the file name well-formed when no sorting was requested

    ax = sns.heatmap(feature_matrix, vmin=-0.5, vmax=0, center=0, cmap='RdBu_r', ax=ax, cbar_kws={'label':value_to_use})
    ax.set_ylabel('cells')
    ax.set_title('GLM feature matrix')

    if save_figure:
        import visual_behavior.visualization.utils as ut
        save_dir = os.path.join(loading.get_ophys_glm_dir(), 'v_'+glm_version, 'figures')
        ut.save_figure(fig, figsize, save_dir, 'feature_matrix_heatmaps', 'glm_feature_matrix_'+glm_version+'_sort_by_'+sort_label+'_'+features)
    return ax
# Save the full feature matrix twice, sorted by 'omissions' and then by 'all-images'.
for sort_feature in ('omissions', 'all-images'):
    plot_feature_matrix(rspm, model_output_type, glm_version, features_to_plot=features_to_plot,
                        sort_by=sort_feature, ax=None, save_figure=True)

# Heatmap restricted to the level-up features, rows sorted by 'visual'.
rows_sorted_by_visual = rspm.sort_values('visual').reset_index()
feature_matrix = rows_sorted_by_visual[level_up_features]
fig, ax = plt.subplots(figsize=(6,10))
heatmap_options = dict(vmin=-0.5, vmax=0, center=0, cmap='RdBu_r', cbar_kws={'label':model_output_type})
ax = sns.heatmap(feature_matrix, ax=ax, **heatmap_options)
ax.set_ylabel('cells')
ax.set_title('GLM feature matrix')
# Per-cell feature heatmaps (features x sessions) for cells that have at least three rows,
# restricted to the two Visual Behavior project codes. One 4x8 figure per cre line.
df = rspm[rspm.project_code.isin(['VisualBehavior', 'VisualBehaviorTask1B'])]
cell_specimen_ids = df.cell_specimen_id.unique()
cre_lines = df.cre_line.unique()
# Count rows per cell once instead of filtering the whole dataframe for every cell
# (the previous comprehension was quadratic in the number of cells). Order of
# matched_cells still follows cell_specimen_ids.
rows_per_cell = df.cell_specimen_id.value_counts()
matched_cells = [cell_specimen_id for cell_specimen_id in cell_specimen_ids
                 if rows_per_cell.get(cell_specimen_id, 0) >= 3]
session_numbers = np.sort(df.session_number.unique())
for cre_line in cre_lines:
    nrows = 4
    ncols = 8
    fig, ax = plt.subplots(nrows, ncols, figsize=(20,20), sharey=True)
    ax = ax.ravel()
    cre_df = df[df.cell_specimen_id.isin(matched_cells)&(df.cre_line==cre_line)]
    cre_df = cre_df.sort_values(by='Full', ascending=False)  # sort by overall variance explained in full model
    cre_matched_cells = cre_df.cell_specimen_id.unique()
    # Only the first nrows*ncols cells (highest 'Full' first) fit in the grid.
    for i, cell_specimen_id in enumerate(cre_matched_cells[:(nrows*ncols)]):
        cdf = df[df.cell_specimen_id==cell_specimen_id]
        ax[i] = sns.heatmap(data=cdf.set_index('session_number')[features_to_plot].sort_values(by='session_number').T,
                            vmin=-0.5, vmax=0.5, cmap='RdBu_r', square=True, ax=ax[i], cbar=False)
        # NOTE(review): the explicit limits pad the heatmap by half a cell on each side and
        # put the first feature at the bottom; confirm this is still wanted with the
        # matplotlib version in use.
        ax[i].set_xlim(-0.5,len(cdf)+0.5)
        ax[i].set_xlabel('')
        ax[i].set_ylim(-0.5,len(features_to_plot)+0.5)
        ax[i].set_yticks(np.arange(0.5,len(features_to_plot)+0.5,1))
        ax[i].set_yticklabels(features_to_plot)
        ax[i].set_title('csid: '+str(cell_specimen_id)+'\n '+cdf.session_type.values[0]+'\ndepth: '+str(int(cdf.imaging_depth.mean())), fontsize=10)
    # for y in np.arange(0,nrows*ncols,ncols):
    #     ax[y].set_xlabel('session #', fontsize=10)
    # for x in range(nrows):
    #     ax[(nrows*ncols-x)-1].set_yticks(np.arange(len(features_to_plot)))
    #     ax[(nrows*ncols-x)-1].set_yticklabels(features_to_plot, rotation=0)
    plt.suptitle(cre_line+' single cell feature vectors\ncolormap is '+model_output_type+' from -0.5 to 0.5', x=0.52, y=1.03, horizontalalignment='center')
    fig.tight_layout()
# Median feature value per session number, one heatmap per cre line
# (Visual Behavior project codes only).
df = rspm[rspm.project_code.isin(['VisualBehavior', 'VisualBehaviorTask1B'])]
fig, ax = plt.subplots(1, 3, figsize=(15,7))
ax = ax.ravel()
for i, cre_line in enumerate(np.sort(cre_lines)):
    tmp = df[df.cre_line==cre_line]
    # Select the feature columns BEFORE aggregating: taking the median over every
    # column also touches string columns such as cre_line / session_type, which
    # raises TypeError on pandas >= 2.0. The numeric result is unchanged.
    session_medians = tmp.groupby(['session_number'])[features_to_plot].median()
    ax[i] = sns.heatmap(data=session_medians.T, ax=ax[i], square=True)
    ax[i].set_xlim(-0.5,len(tmp.session_number.unique())+0.5)
    ax[i].set_ylim(-0.5,len(features_to_plot)-0.5)
    ax[i].set_yticks(np.arange(0.5,len(features_to_plot)+0.5,1))
    ax[i].set_yticklabels(features_to_plot)
    ax[i].set_title(cre_line)
fig.tight_layout()
# Same per-cre-line median heatmaps as above, but over all project codes.
df = rspm.copy()
fig, ax = plt.subplots(1, 3, figsize=(15,7))
for i, cre_line in enumerate(np.sort(cre_lines)):
    tmp = df[df.cre_line==cre_line]
    # Select the feature columns BEFORE aggregating: taking the median over every
    # column also touches string columns such as cre_line / session_type, which
    # raises TypeError on pandas >= 2.0. The numeric result is unchanged.
    session_medians = tmp.groupby(['session_number'])[features_to_plot].median()
    ax[i] = sns.heatmap(data=session_medians.T, ax=ax[i], square=True)
    ax[i].set_xlim(-0.5,len(tmp.session_number.unique())+0.5)
    ax[i].set_ylim(-0.5,len(features_to_plot)-0.5)
    ax[i].set_yticks(np.arange(0.5,len(features_to_plot)+0.5,1))
    ax[i].set_yticklabels(features_to_plot)
    ax[i].set_title(cre_line)
fig.tight_layout()
# df = rspm[rspm.project_code.isin(['VisualBehavior', 'VisualBehaviorTask1B'])]
# for c,cre_line in enumerate(cre_lines):
# tmp = df[(df.cre_line==cre_line)]
# containers = np.sort(tmp.container_id.unique())
# print(cre_line, len(containers))
# fig, ax = plt.subplots(4,5, figsize=(20,20))
# ax = ax.ravel()
# for i,container_id in enumerate(containers):
# cdf = tmp[(tmp.container_id==container_id)]
# ax[i] = sns.heatmap(data=cdf.groupby(['session_number']).median()[features_to_plot].T, ax=ax[i])
# ax[i].set_xlim(-0.5,len(cdf.session_number.unique())+0.5)
# ax[i].set_ylim(-0.5,len(features_to_plot)-0.5)
# ax[i].set_yticks(np.arange(0.5,len(features_to_plot)+0.5,1))
# ax[i].set_yticklabels(features_to_plot);
# ax[i].set_title(cre_line+' '+str(int(cdf.imaging_depth.mean()))+'\n'+str(container_id))
# fig.tight_layout()
# Per-cell feature heatmaps again, this time on the current `df`, reusing the
# `matched_cells` and `cre_lines` computed earlier. One 4x8 figure per cre line.
session_numbers = np.sort(df.session_number.unique())
for cre_line in cre_lines:
    nrows, ncols = 4, 8
    fig, ax = plt.subplots(nrows, ncols, figsize=(20,20), sharey=True)
    ax = ax.ravel()

    # Matched cells of this cre line, ordered by the 'Full' column (highest first).
    is_matched = df.cell_specimen_id.isin(matched_cells)
    is_this_cre = (df.cre_line==cre_line)
    cre_df = df[is_matched&is_this_cre]
    cre_df = cre_df.sort_values(by='Full', ascending=False)
    cre_matched_cells = cre_df.cell_specimen_id.unique()

    n_features_shown = len(features_to_plot)
    for i, cell_specimen_id in enumerate(cre_matched_cells[:(nrows*ncols)]):
        cdf = df[df.cell_specimen_id==cell_specimen_id]
        features_by_session = cdf.set_index('session_number')[features_to_plot]
        features_by_session = features_by_session.sort_values(by='session_number')
        ax[i] = sns.heatmap(data=features_by_session.T,
                            vmin=-0.5, vmax=0.5, cmap='RdBu_r', square=True, ax=ax[i], cbar=False)
        ax[i].set_xlim(-0.5, len(cdf)-0.5)
        ax[i].set_xlabel('')
        ax[i].set_ylim(-0.5, n_features_shown-0.5)
        ax[i].set_yticks(np.arange(0, n_features_shown, 1))
        ax[i].set_yticklabels(features_to_plot)
        panel_title = 'csid: '+str(cell_specimen_id)+'\n '+cdf.session_type.values[0]+'\ndepth: '+str(int(cdf.imaging_depth.mean()))
        ax[i].set_title(panel_title, fontsize=10)
    # for y in np.arange(0,nrows*ncols,ncols):
    #     ax[y].set_xlabel('session #', fontsize=10)
    # for x in range(nrows):
    #     ax[(nrows*ncols-x)-1].set_yticks(np.arange(len(features_to_plot)))
    #     ax[(nrows*ncols-x)-1].set_yticklabels(features_to_plot, rotation=0)
    figure_title = cre_line+' single cell feature vectors\ncolormap is '+model_output_type+' from -0.5 to 0.5'
    plt.suptitle(figure_title, x=0.52, y=1.03, horizontalalignment='center')
    fig.tight_layout()
# One small scatterplot per feature pair, points colored by cre line.
feature_pairs = (
    ('visual', 'behavioral'),
    ('omissions', 'pupil'),
    ('pupil', 'running'),
)
for x_feature, y_feature in feature_pairs:
    fig, ax = plt.subplots(figsize=(4,4))
    ax = sns.scatterplot(data=rspm, x=x_feature, y=y_feature, hue='cre_line', ax=ax)
# Pick entries 0, 2 and 3 of seaborn's current default palette, one per cre line.
default_palette = sns.color_palette()
colors = [default_palette[k] for k in (0, 2, 3)]
cre_lines = np.sort(rspm.cre_line.unique())

# One figure per feature: value against session number, one line per cre line.
for metric in features_to_plot:
    fig, ax = plt.subplots(figsize=(6,4))
    sns.pointplot(data=rspm, x='session_number', y=metric,
                  hue='cre_line', hue_order=cre_lines, palette=colors, ax=ax)

# Sorted unique cre lines / session numbers, plus the first of each.
cre_lines = np.sort(rspm.cre_line.unique())
session_numbers = np.sort(rspm.session_number.unique())
cre_line, session_number = cre_lines[0], session_numbers[0]
def get_colors_for_session_numbers():
    """
    Return six colors: every other entry of the 6-color 'Reds_r' palette,
    followed by every other entry of the 6-color 'Blues_r' palette.

    :return: list of three red shades followed by three blue shades
    """
    every_other = slice(None, None, 2)
    red_shades, blue_shades = [
        sns.color_palette(palette_name, 6)[every_other]
        for palette_name in ('Reds_r', 'Blues_r')
    ]
    return red_shades + blue_shades
# One panel per cre line; within a panel, one pointplot series per session number,
# colored with the red/blue session palette.
colors = get_colors_for_session_numbers()
# colors = [c[0], c[2], c[3], c[5]]
fig, ax = plt.subplots(1, 3, figsize=(16, 8))
for i, cre_line in enumerate(cre_lines):
    in_cre_line = (rspm.cre_line==cre_line)
    for c, session_number in enumerate(session_numbers):
        in_session = (rspm.session_number==session_number)
        # Long format: one row per (feature, value) pair for this cre line and session.
        data = rspm[in_cre_line&in_session][features_to_plot].melt()
        ax[i] = sns.pointplot(data=data, x='value', y='variable', ax=ax[i], color=colors[c])
    ax[i].set_title(cre_line)
# NOTE(review): plt.legend targets the current axes only, so the legend is drawn on
# a single panel - confirm that is the intent.
plt.legend(labels=session_numbers)
fig.tight_layout()
from sklearn.decomposition import PCA

rspm.head()

# Feature columns plus identifying columns, with incomplete rows dropped before fitting.
pca_columns = features_to_plot+['cre_line','identifier']
data = rspm[pca_columns].copy()
data = data.dropna()

# Fit a PCA that keeps as many components as there are features.
n_features = len(features_to_plot)
n_components = n_features
pca = PCA(n_components=n_components)
pca_result = pca.fit_transform(data[features_to_plot].values)

# Store the first three component scores on the dataframe.
for pc_position, pc_column in enumerate(('pc1', 'pc2', 'pc3')):
    data[pc_column] = pca_result[:, pc_position]

print('Explained variation per principal component: {}'.format(pca.explained_variance_ratio_))

# Cumulative explained variance and where it crosses 90% / 95%.
np.cumsum(pca.explained_variance_ratio_)
np.searchsorted(np.cumsum(pca.explained_variance_ratio_), .90)
np.searchsorted(np.cumsum(pca.explained_variance_ratio_), .95)

# Scree plot: explained variance ratio per component.
fig, ax = plt.subplots()
component_numbers = np.arange(n_components)
ax.plot(component_numbers, pca.explained_variance_ratio_, 'o-k')
ax.set_xlabel('PC number')
ax.set_ylabel('variance explained')
# NOTE(review): the title is hard-coded; re-check it against the searchsorted
# result above whenever the data or feature list changes.
ax.set_title('first 8 PCs explain >95% of the variance')
# Heatmap of the PCA component matrix: one row per component, one column per feature.
fig, ax = plt.subplots(figsize=(6,6))
colorbar_options = {"drawedges": False, "shrink": 0.7, "label": 'weight'}
ax = sns.heatmap(pca.components_, vmin=-1, vmax=1, cmap='RdBu_r', ax=ax, square=True,
                 robust=True, cbar_kws=colorbar_options)
ax.set_ylabel('principal components')
ax.set_xlabel('features')
# ax.set_title('principal axes in feature space \n(directions of maximum variance in the data)')
ax.set_ylim(0, n_components)
feature_tick_centers = np.arange(0.5, len(features_to_plot)+0.5, 1)
ax.set_xticks(feature_tick_centers)
ax.set_xticklabels(features_to_plot, rotation=90);
pca.components_.shape
# Line plot of the feature weights of components 0..N_PCs-1.
fig, ax = plt.subplots(figsize=(12,4))
N_PCs = 8
for PC in range(N_PCs):
    component_weights = pca.components_[PC,:]
    ax.plot(component_weights)
ax.legend(np.arange(N_PCs), title='PC', bbox_to_anchor=(1,1))
# Shaded reference band of +/- 1/sqrt(32) around zero.
# NOTE(review): 32 is hard-coded and is not derived from len(features_to_plot) - confirm.
band_half_width = 1/np.sqrt(32)
ax.axhspan(band_half_width, -band_half_width, zorder=-np.inf, alpha=0.25)
feature_positions = np.arange(len(features_to_plot))
ax.set_xticks(feature_positions)
ax.set_xticklabels(features_to_plot, rotation=45, ha='right')
ax.set_ylabel('weight')
ax.set_ylim(-1.1,1.1)
fig.tight_layout()
# Line plot of the feature weights of components 8..19 (0-based, matching the
# labelling of the previous figure).
fig, ax = plt.subplots(figsize=(12,4))
# Clamp to the number of fitted components so a shorter feature list cannot raise IndexError.
pc_indices = np.arange(8, min(20, pca.components_.shape[0]))
for PC in pc_indices:
    ax.plot(pca.components_[PC,:])
# Legend labels are the plotted component indices. The previous labels
# (np.arange(10,21)) had 11 entries for 12 lines and were shifted by two.
ax.legend(pc_indices, title='PC', bbox_to_anchor=(1,1))
# Shaded reference band of +/- 1/sqrt(32) around zero.
# NOTE(review): 32 is hard-coded and is not derived from len(features_to_plot) - confirm.
ax.axhspan(1/np.sqrt(32), -1/np.sqrt(32), zorder=-np.inf, alpha=0.25)
ax.set_xticks(np.arange(len(features_to_plot)))
ax.set_xticklabels(features_to_plot, rotation=45, ha='right')
ax.set_ylabel('weight')
ax.set_ylim(-1.1,1.1)
fig.tight_layout()
# Heatmap of the covariance matrix returned by the fitted PCA (features x features).
fig, ax = plt.subplots(figsize=(6,6))
covariance = pca.get_covariance()
colorbar_options = {"drawedges": False, "shrink": 0.7, "label": 'covariance'}
ax = sns.heatmap(covariance, vmin=-0.002, vmax=0.002, cmap='RdBu_r', ax=ax, square=True,
                 robust=True, cbar_kws=colorbar_options)
ax.set_title('covariance matrix')
ax.set_ylim(0, n_features)
feature_tick_centers = np.arange(0.5, len(features_to_plot)+0.5, 1)
ax.set_xticks(feature_tick_centers)
ax.set_xticklabels(features_to_plot, rotation=90);
ax.set_yticklabels(features_to_plot, rotation=0);
# Heatmaps of the PCA-transformed data (rows x components): first in the original
# row order, then with rows sorted by their value on the first component.
# NOTE(review): the title says '% variance explained' but the plotted matrix is
# the output of pca.fit_transform - confirm the intended wording.
rows_sorted_by_first_pc = pca_result[np.argsort(pca_result[:,0])]
for score_matrix in (pca_result, rows_sorted_by_first_pc):
    fig, ax = plt.subplots(figsize=(6,10))
    ax = sns.heatmap(score_matrix, cmap='RdBu_r', ax=ax,
                     robust=True, cbar_kws={"drawedges": False, "shrink": 0.7, "label": 'activation'})
    ax.set_ylabel('cells')
    ax.set_xlabel('PC')
    ax.set_title('% variance explained in PC space')
    n_rows, n_cols = pca_result.shape[0], pca_result.shape[1]
    ax.set_ylim(0, n_rows)
    ax.set_xlim(0, n_cols)
    ax.set_xticks(np.arange(0, n_cols))
    # ax.set_xticklabels(features_to_plot, rotation=90);
data.keys()
# pc1 and pc2 columns in rspm correspond to pca_results for those PCs - is this the 'score' per cell?
# Scatter of the stored component scores, colored by cre line: pc1 vs pc2, then pc2 vs pc3.
fig, ax = plt.subplots(1, 2, figsize=(12,6))
for panel, (x_column, y_column) in enumerate((("pc1", "pc2"), ("pc2", "pc3"))):
    ax[panel] = sns.scatterplot(data=data, x=x_column, y=y_column, hue="cre_line",
                                palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[panel])
    # ax[panel].set_xlim(-5,10)
    # ax[panel].set_ylim(-5,10)

# All component scores as a dataframe indexed by identifier, with the cre line attached.
pca_result_df = pd.DataFrame(pca_result, index=data.identifier)
pca_result_df['cre_line'] = data['cre_line'].values
# pc1 and pc2 columns in rspm correspond to pca_results for those PCs - is this the 'score' per cell?
# Column positions in pca_result_df of the first four principal components.
# Columns are 0-based, so PC3 is column 2 and PC4 is column 3. The previous
# values (3 and 4) skipped column 2 and plotted the 4th and 5th components instead.
PC1 = 0
PC2 = 1
PC3 = 2
PC4 = 3
# Three panels, colored by cre line: PC1 vs PC2, PC2 vs PC3, PC3 vs PC4.
fig, ax = plt.subplots(1, 3, figsize=(15,5))
ax = ax.ravel()
for i, (x_pc, y_pc) in enumerate([(PC1, PC2), (PC2, PC3), (PC3, PC4)]):
    ax[i] = sns.scatterplot(data=pca_result_df, x=x_pc, y=y_pc, hue="cre_line",
                            palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[i])
    # ax[i].set_xlim(-100,100)
    # ax[i].set_ylim(-100,100)
fig.tight_layout()
# Repeat the PCA separately for each cre line and show its component matrix as a heatmap.
for cre_line in rspm.cre_line.unique():
    # Rows of this cre line with complete feature values.
    pca_columns = features_to_plot+['cre_line','identifier']
    data = rspm[pca_columns].copy()
    data = data[data.cre_line==cre_line]
    data = data.dropna()

    n_features = len(features_to_plot)
    n_components = n_features
    pca = PCA(n_components=n_components)
    pca_result = pca.fit_transform(data[features_to_plot].values)
    for pc_position, pc_column in enumerate(('pc1', 'pc2', 'pc3')):
        data[pc_column] = pca_result[:, pc_position]

    # One row per component, one column per feature.
    fig, ax = plt.subplots(figsize=(6,6))
    colorbar_options = {"drawedges": False, "shrink": 0.7, "label": 'weight'}
    ax = sns.heatmap(pca.components_, vmin=-1, vmax=1, cmap='RdBu_r', ax=ax, square=True,
                     robust=True, cbar_kws=colorbar_options)
    ax.set_ylabel('principal components')
    ax.set_xlabel('features')
    ax.set_title(cre_line)
    ax.set_ylim(0, n_components)
    ax.set_xticks(np.arange(0.5, len(features_to_plot)+0.5, 1))
    ax.set_xticklabels(features_to_plot, rotation=90)